{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "G3MMAcssHTML"
      },
      "source": [
        "<link rel=\"stylesheet\" href=\"/site-assets/css/gemma.css\">\n",
        "<link rel=\"stylesheet\" href=\"https://fonts.googleapis.com/css2?family=Google+Symbols:opsz,wght,FILL,GRAD@20..48,100..700,0..1,-50..200\" />"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3_lX1k54KKrx"
      },
      "source": [
        "Copyright 2024 Google LLC."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "Gr4W9nspKGtb"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Q5_nIe-8gdJV"
      },
      "source": [
        "# Generate PaliGemma output with Keras\n",
        "\n",
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "<td>\n",
        "<a target=\"_blank\" href=\"https://ai.google.dev/gemma/docs/paligemma/inference-with-keras\"><img src=\"https://ai.google.dev/static/site-assets/images/docs/notebook-site-button.png\" height=\"32\" width=\"32\" />View on ai.google.dev</a>\n",
        "</td>\n",
        "<td>\n",
        "<a target=\"_blank\" href=\"https://colab.research.google.com/github/google/generative-ai-docs/blob/main/site/en/gemma/docs/paligemma/inference-with-keras.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "</td>\n",
        "<td>\n",
        "<a target=\"_blank\" href=\"https://github.com/google/generative-ai-docs/blob/main/site/en/gemma/docs/paligemma/inference-with-keras.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "</td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9hhIuS9sEKHx"
      },
      "source": [
        "PaliGemma models have *multimodal* capabilities, allowing you to generate output using both text and image input data. You can use image data with these models to provide additional context for your requests, or use the model to analyze the content of images. This tutorial shows you how to use PaliGemma with Keras to analyze images and answer questions about them."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9JaII1xxfbfz"
      },
      "source": [
        "## What's in this notebook\n",
        "\n",
        "This notebook uses PaliGemma with Keras and shows you how to:\n",
        "\n",
        "* Install Keras and the required dependencies\n",
        "* Download `PaliGemmaCausalLM`, a pre-trained PaliGemma variant for causal visual language modeling, and use it to create a model\n",
        "* Test the model's ability to infer information about supplied images"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Bf7AUi02fcPL"
      },
      "source": [
        "## Before you begin\n",
        "\n",
        "Before going through this notebook, you should be familiar with Python code, as well as how large language models (LLMs) are trained. You don't need to be familiar with Keras, but basic knowledge about Keras is helpful when reading through the example code."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "380it9kIhzmm"
      },
      "source": [
        "## Setup\n",
        "\n",
        "The following sections explain the preliminary steps for getting a notebook to use a PaliGemma model, including model access, getting an API key, and configuring the notebook runtime."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6uN0EpPJh7t5"
      },
      "source": [
        "### Get access to PaliGemma\n",
        "\n",
        "Before using PaliGemma for the first time, you must request access to the model through Kaggle by completing the following steps:\n",
        "\n",
        "1. Log in to [Kaggle](https://www.kaggle.com), or create a new Kaggle account if you don't already have one.\n",
        "1. Go to the [PaliGemma model card](https://www.kaggle.com/models/google/paligemma-2/) and click **Request Access**.\n",
        "1. Complete the consent form and accept the terms and conditions."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Uz_pBNrWiDqe"
      },
      "source": [
        "### Configure your API key\n",
        "\n",
        "To use PaliGemma, you must provide your Kaggle username and a Kaggle API key.\n",
        "\n",
        "To generate a Kaggle API key, open your [**Settings** page in Kaggle](https://www.kaggle.com/settings) and click **Create New Token**. This triggers the download of a `kaggle.json` file containing your API credentials.\n",
        "\n",
        "Then, in Colab, select **Secrets** (🔑) in the left pane and add your Kaggle username and Kaggle API key. Store your username under the name `KAGGLE_USERNAME` and your API key under the name `KAGGLE_KEY`."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wUB4JE0Hlxce"
      },
      "source": [
        "### Select the runtime\n",
        "\n",
        "To complete this tutorial, you'll need to have a Colab runtime with sufficient resources to run the PaliGemma model. In this case, you can use a T4 GPU:\n",
        "\n",
        "1. In the upper-right of the Colab window, click the **▾ (Additional connection options)** dropdown menu.\n",
        "1. Select **Change runtime type**.\n",
        "1. Under **Hardware accelerator**, select **T4 GPU**."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mIvv2Yo3lycQ"
      },
      "source": [
        "### Set environment variables\n",
        "\n",
        "Set the environment variables for `KAGGLE_USERNAME`, `KAGGLE_KEY`, and `KERAS_BACKEND`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rdgwLyZLQBkP"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "from google.colab import userdata\n",
        "\n",
        "# Set environment variables for the Kaggle credentials and the Keras backend\n",
        "os.environ[\"KAGGLE_USERNAME\"] = userdata.get('KAGGLE_USERNAME')\n",
        "os.environ[\"KAGGLE_KEY\"] = userdata.get('KAGGLE_KEY')\n",
        "os.environ[\"KERAS_BACKEND\"] = \"jax\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4a3Q4VCLljR9"
      },
      "source": [
        "### Install Keras\n",
        "\n",
        "Run the following cell to install the Keras packages used in this notebook (`keras-nlp` and `keras-hub`) along with `kagglehub`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DoYMMytAaMRJ"
      },
      "outputs": [],
      "source": [
        "!pip install -U -q keras-nlp keras-hub kagglehub"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "y2Y7BRtRgfvG"
      },
      "source": [
        "### Import dependencies and configure Keras\n",
        "\n",
        "Import the dependencies needed for this notebook. You'll also set the default Keras floating-point type to `bfloat16` so that the framework uses less memory."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MHECpBe6LE7y"
      },
      "outputs": [],
      "source": [
        "import keras\n",
        "import keras_hub\n",
        "import numpy as np\n",
        "import PIL\n",
        "import requests\n",
        "import io\n",
        "import matplotlib\n",
        "import re\n",
        "import matplotlib.pyplot as plt\n",
        "import matplotlib.patches as patches\n",
        "from PIL import Image\n",
        "\n",
        "keras.config.set_floatx(\"bfloat16\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "X-LE2E1uiSpP"
      },
      "source": [
        "## Load the model\n",
        "\n",
        "Now that you've set everything up, you can download the pre-trained model and create some utility methods to help your model generate its responses.\n",
        "In this step, you download a model using `PaliGemmaCausalLM` from Keras Hub. This class helps you manage and run the causal visual language model structure of PaliGemma. A *causal visual language model* predicts the next token based on previous tokens. Keras Hub provides implementations of many popular [model architectures](https://keras.io/keras_hub/api/models/).\n",
        "\n",
        "\n",
        "Create the model using the `from_preset` method and print its summary. This process will take about a minute to complete."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "abNuIP8D_9At"
      },
      "outputs": [],
      "source": [
        "paligemma = keras_hub.models.PaliGemmaCausalLM.from_preset(\"kaggle://keras/paligemma2/keras/pali_gemma2_mix_3b_224\")\n",
        "paligemma.summary()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FBsWvKEvoGMe"
      },
      "source": [
        "## Create utility methods\n",
        "\n",
        "To help you generate responses from your model and interpret its output, create the following utility methods:\n",
        "\n",
        "*   **`crop_and_resize`:** Helper method for `read_image`. This method center-crops the image to a square and resizes it to the passed-in size, so that the final image is resized without skewing its proportions.\n",
        "*   **`read_image`:** Downloads an image from a valid URL, resizes it so that it fits the model's constraints, and puts it into an array that can be interpreted by the model. You need this method to pass the image to the model.\n",
        "*   **`parse_bbox_and_labels`:** Extracts bounding box coordinates and labels from the text that the model generates for a `detect` prompt.\n",
        "*   **`display_boxes`:** Draws the parsed bounding boxes and their labels on top of the image.\n",
        "*   **`display_segment_output`:** Overlays the parsed segmentation masks on top of the image.\n",
        "\n",
        "You'll use `read_image` in the next step of this notebook.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "S6_XQjhpvXiG"
      },
      "outputs": [],
      "source": [
        "def crop_and_resize(image, target_size):\n",
        "    width, height = image.size\n",
        "    source_size = min(image.size)\n",
        "    left = width // 2 - source_size // 2\n",
        "    top = height // 2 - source_size // 2\n",
        "    right, bottom = left + source_size, top + source_size\n",
        "    return image.resize(target_size, box=(left, top, right, bottom))\n",
        "\n",
        "def read_image(url, target_size):\n",
        "    contents = io.BytesIO(requests.get(url).content)\n",
        "    image = PIL.Image.open(contents)\n",
        "    image = crop_and_resize(image, target_size)\n",
        "    image = np.array(image)\n",
        "    # Remove alpha channel if necessary.\n",
        "    if image.shape[2] == 4:\n",
        "        image = image[:, :, :3]\n",
        "    return image\n",
        "\n",
        "def parse_bbox_and_labels(detokenized_output: str):\n",
        "  matches = re.finditer(\n",
        "      r'<loc(?P<y0>\\d\\d\\d\\d)><loc(?P<x0>\\d\\d\\d\\d)><loc(?P<y1>\\d\\d\\d\\d)><loc(?P<x1>\\d\\d\\d\\d)>'\n",
        "      r' (?P<label>.+?)( ;|$)',\n",
        "      detokenized_output,\n",
        "  )\n",
        "  labels, boxes = [], []\n",
        "  fmt = lambda x: float(x) / 1024.0\n",
        "  for m in matches:\n",
        "    d = m.groupdict()\n",
        "    boxes.append([fmt(d['y0']), fmt(d['x0']), fmt(d['y1']), fmt(d['x1'])])\n",
        "    labels.append(d['label'])\n",
        "  return np.array(boxes), np.array(labels)\n",
        "\n",
        "def display_boxes(image, boxes, labels, target_image_size):\n",
        "  # target_image_size is (width, height), matching the PIL convention.\n",
        "  w, h = target_image_size\n",
        "  fig, ax = plt.subplots()\n",
        "  ax.imshow(image)\n",
        "  for i in range(boxes.shape[0]):\n",
        "      # Boxes are normalized (y0, x0, y1, x1); scale them to pixels.\n",
        "      y, x, y2, x2 = boxes[i] * np.array([h, w, h, w])\n",
        "      width = x2 - x\n",
        "      height = y2 - y\n",
        "      # Create a Rectangle patch\n",
        "      rect = patches.Rectangle((x, y),\n",
        "                               width,\n",
        "                               height,\n",
        "                               linewidth=1,\n",
        "                               edgecolor='r',\n",
        "                               facecolor='none')\n",
        "      # Add label\n",
        "      plt.text(x, y, labels[i], color='red', fontsize=12)\n",
        "      # Add the patch to the Axes\n",
        "      ax.add_patch(rect)\n",
        "\n",
        "  plt.show()\n",
        "\n",
        "def display_segment_output(image, bounding_box, segment_mask, target_image_size):\n",
        "    target_width, target_height = target_image_size\n",
        "    # Initialize a full mask; NumPy shapes are (height, width)\n",
        "    full_mask = np.zeros((target_height, target_width), dtype=np.uint8)\n",
        "\n",
        "    for bbox, mask in zip(bounding_box, segment_mask):\n",
        "        y1, x1, y2, x2 = bbox\n",
        "        x1 = int(x1 * target_width)\n",
        "        y1 = int(y1 * target_height)\n",
        "        x2 = int(x2 * target_width)\n",
        "        y2 = int(y2 * target_height)\n",
        "\n",
        "        # Ensure mask is 2D before converting to Image\n",
        "        if mask.ndim == 3:\n",
        "            mask = mask.squeeze(axis=-1)\n",
        "        mask = Image.fromarray(mask)\n",
        "        mask = mask.resize((x2 - x1, y2 - y1), resample=Image.NEAREST)\n",
        "        mask = np.array(mask)\n",
        "        binary_mask = (mask > 0.5).astype(np.uint8)\n",
        "\n",
        "\n",
        "        # Place the binary mask onto the full mask\n",
        "        full_mask[y1:y2, x1:x2] = np.maximum(full_mask[y1:y2, x1:x2], binary_mask)\n",
        "    cmap = plt.get_cmap('jet')\n",
        "    colored_mask = cmap(full_mask / 1.0)\n",
        "    colored_mask = (colored_mask[:, :, :3] * 255).astype(np.uint8)\n",
        "    if isinstance(image, Image.Image):\n",
        "        image = np.array(image)\n",
        "    blended_image = image.copy()\n",
        "    mask_indices = full_mask > 0\n",
        "    alpha = 0.5\n",
        "\n",
        "    for c in range(3):\n",
        "        blended_image[:, :, c] = np.where(mask_indices,\n",
        "                                          (1 - alpha) * image[:, :, c] + alpha * colored_mask[:, :, c],\n",
        "                                          image[:, :, c])\n",
        "\n",
        "    fig, ax = plt.subplots()\n",
        "    ax.imshow(blended_image)\n",
        "    plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AeVUHA_zP8ZF"
      },
      "source": [
        "## Generate output\n",
        "\n",
        "After loading the model and creating utility methods, you can prompt the model with image and text data to generate responses. PaliGemma models are trained with specific prompt syntax for specific tasks, such as `answer`, `caption`, and `detect`. For more information about PaliGemma prompt task syntax, see [PaliGemma prompt and system instructions](https://ai.google.dev/gemma/docs/paligemma/prompt-system-instructions#prompt_task_syntax).\n",
        "\n",
        "Prepare an image for use in a generation prompt by using the following code to load a test image into an object:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GtPeUkGJFvXS"
      },
      "outputs": [],
      "source": [
        "target_size = (224, 224)\n",
        "image_url = 'https://storage.googleapis.com/keras-cv/models/paligemma/cow_beach_1.png'\n",
        "cow_image = read_image(image_url, target_size)\n",
        "matplotlib.pyplot.imshow(cow_image)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "R3FE63iMqb2G"
      },
      "source": [
        "### Answer in a specific language\n",
        "\n",
        "The following example code shows how to prompt the PaliGemma model for information about an object appearing in a provided image. This example uses the `answer {lang}` syntax and shows additional questions in other languages:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WQVmhpr6qdNH"
      },
      "outputs": [],
      "source": [
        "prompt = 'answer en where is the cow standing?\\n'\n",
        "# prompt = 'answer no hvor står kuen?\\n'\n",
        "# prompt = 'answer fr quelle couleur est le ciel?\\n'\n",
        "# prompt = 'answer pt qual a cor do animal?\\n'\n",
        "\n",
        "output = paligemma.generate(\n",
        "    inputs={\n",
        "        \"images\": cow_image,\n",
        "        \"prompts\": prompt,\n",
        "    }\n",
        ")\n",
        "print(output)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tMPv8tfISvif"
      },
      "source": [
        "Note: Prompts that use PaliGemma prompt command syntax must end with an \"`\\n`\" character."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eUgPcUGDpsdx"
      },
      "source": [
        "### Use `detect` prompt\n",
        "\n",
        "The following example code uses the `detect` prompt syntax to locate an object in the provided image. The code uses the previously defined `parse_bbox_and_labels()` and `display_boxes()` functions to interpret the model output and display the generated bounding boxes."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "noRIZtqnpzNn"
      },
      "outputs": [],
      "source": [
        "prompt = 'detect cow\\n'\n",
        "output = paligemma.generate(\n",
        "    inputs={\n",
        "        \"images\": cow_image,\n",
        "        \"prompts\": prompt,\n",
        "    }\n",
        ")\n",
        "boxes, labels = parse_bbox_and_labels(output)\n",
        "display_boxes(cow_image, boxes, labels, target_size)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TET0Bz10mzxD"
      },
      "source": [
        "### Use `segment` prompt\n",
        "\n",
        "The following example code uses the `segment` prompt syntax to locate the area of an image occupied by an object. It uses the Google `big_vision` library to interpret the model output and generate a mask for the segmented object.\n",
        "\n",
        "Before getting started, install the `big_vision` library and its dependencies, as shown in this code example:\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tuxyIMAvsM4B"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import sys\n",
        "\n",
        "# Colab runtimes with remote TPUs are not supported.\n",
        "if \"COLAB_TPU_ADDR\" in os.environ:\n",
        "  raise RuntimeError(\"It seems you are using Colab with remote TPUs, which is not supported.\")\n",
        "\n",
        "# Fetch big_vision repository if python doesn't know about it and install\n",
        "# dependencies needed for this notebook.\n",
        "if not os.path.exists(\"big_vision_repo\"):\n",
        "  !git clone --quiet --branch=main --depth=1 \\\n",
        "     https://github.com/google-research/big_vision big_vision_repo\n",
        "\n",
        "# Append big_vision code to python import path\n",
        "if \"big_vision_repo\" not in sys.path:\n",
        "  sys.path.append(\"big_vision_repo\")\n",
        "\n",
        "\n",
        "# Install missing dependencies. Assume jax~=0.4.25 with GPU available.\n",
        "!pip3 install -q \"overrides\" \"ml_collections\" \"einops~=0.7\" \"sentencepiece\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "18S4uHWqutps"
      },
      "source": [
        "For this segmentation example, load and prepare a different image that includes a cat."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Al6NedG4uNhp"
      },
      "outputs": [],
      "source": [
        "cat = read_image('https://big-vision-paligemma.hf.space/file=examples/barsik.jpg', target_size)\n",
        "matplotlib.pyplot.imshow(cat)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VRJYikeXu1As"
      },
      "source": [
        "The following code defines a function that parses the segment output from PaliGemma:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_ryCMfGVuxyd"
      },
      "outputs": [],
      "source": [
        "import big_vision.evaluators.proj.paligemma.transfers.segmentation as segeval\n",
        "\n",
        "reconstruct_masks = segeval.get_reconstruct_masks('oi')\n",
        "\n",
        "def parse_segments(detokenized_output: str) -> tuple[np.ndarray, np.ndarray]:\n",
        "  matches = re.finditer(\n",
        "      r'<loc(?P<y0>\\d\\d\\d\\d)><loc(?P<x0>\\d\\d\\d\\d)><loc(?P<y1>\\d\\d\\d\\d)><loc(?P<x1>\\d\\d\\d\\d)>'\n",
        "      + ''.join(rf'<seg(?P<s{i}>\\d\\d\\d)>' for i in range(16)),\n",
        "      detokenized_output,\n",
        "  )\n",
        "  boxes, segs = [], []\n",
        "  fmt_box = lambda x: float(x) / 1024.0\n",
        "  for m in matches:\n",
        "    d = m.groupdict()\n",
        "    boxes.append([fmt_box(d['y0']), fmt_box(d['x0']), fmt_box(d['y1']), fmt_box(d['x1'])])\n",
        "    segs.append([int(d[f's{i}']) for i in range(16)])\n",
        "  return np.array(boxes), np.array(reconstruct_masks(np.array(segs)))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QITD66qJvCTO"
      },
      "source": [
        "Query PaliGemma to segment the cat in the image:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fB7to-J4u5zY"
      },
      "outputs": [],
      "source": [
        "prompt = 'segment cat\\n'\n",
        "output = paligemma.generate(\n",
        "    inputs={\n",
        "        \"images\": cat,\n",
        "        \"prompts\": prompt,\n",
        "    }\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XZeu6-bovFvz"
      },
      "source": [
        "Visualize the mask generated by PaliGemma:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GcjOvoPbvAI-"
      },
      "outputs": [],
      "source": [
        "bboxes, seg_masks = parse_segments(output)\n",
        "display_segment_output(cat, bboxes, seg_masks, target_size)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jI9YAgLuGyAb"
      },
      "source": [
        "### Batch prompts\n",
        "\n",
        "You can send more than one prompt in a single `generate` call as a batch. The following example passes a list of prompts together with a list of images of the same length, pairing each prompt with one image.\n",
        "\n",
        "Important: Each prompt command must end with a newline (`\\n`) character, as shown."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UZle4sJP6YwB"
      },
      "outputs": [],
      "source": [
        "prompts = [\n",
        "    'answer en where is the cow standing?\\n',\n",
        "    'answer en what color is the cow?\\n',\n",
        "    'describe en\\n',\n",
        "    'detect cow\\n',\n",
        "    'segment cow\\n',\n",
        "]\n",
        "images = [cow_image, cow_image, cow_image, cow_image, cow_image]\n",
        "outputs = paligemma.generate(\n",
        "    inputs={\n",
        "        \"images\": images,\n",
        "        \"prompts\": prompts,\n",
        "    }\n",
        ")\n",
        "for output in outputs:\n",
        "    print(output)"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "inference-with-keras.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
